{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4b4eb8c0-6156-4d2b-86c7-8546daa6632a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules automatically before each cell executes, so edits to project code are picked up.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "# Add the parent directory to the import path (presumably where the `networks` package lives - confirm).\n",
    "sys.path.append('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "94e9d30d-9e89-4ee8-800a-6f440228144a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from networks.inception import InceptionResnet, SupIncepResnet\n",
    "from networks.simple_cnn import BaselineCNNClassifier\n",
    "from networks.resnet_big import SupCEResNet, SupConResNet, LinearClassifier\n",
    "import torch\n",
    "from torchsummary import summary\n",
    "from thop import profile\n",
    "import numpy as np\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "398ddce7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def measure_time_gpu(model, device, rep):\n",
    "    \"\"\"Measure single-sample inference latency of `model` with CUDA events.\n",
    "\n",
    "    The model is moved to `device`, warmed up with 100 untimed forward passes,\n",
    "    and then timed for `rep` forward passes on one random (1, 1, 29, 29) input.\n",
    "    Every pass runs in eval mode with gradient tracking disabled; the model's\n",
    "    previous train/eval mode is restored before returning.\n",
    "\n",
    "    Args:\n",
    "        model: torch.nn.Module to benchmark.\n",
    "        device: CUDA device for the model and the input, e.g. 'cuda'.\n",
    "        rep: number of timed forward passes; must be at least 1.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (mean, std) of the per-pass latency in milliseconds.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `rep` is smaller than 1.\n",
    "    \"\"\"\n",
    "    # Validate first: zero repetitions would otherwise yield NaN statistics.\n",
    "    if rep < 1:\n",
    "        raise ValueError(f'rep must be >= 1, got {rep}')\n",
    "\n",
    "    model = model.to(device=device)\n",
    "    dummy_input = torch.randn(1, 1, 29, 29, dtype=torch.float, device=device)\n",
    "    starter = torch.cuda.Event(enable_timing=True)\n",
    "    ender = torch.cuda.Event(enable_timing=True)\n",
    "    timings = np.zeros((rep, 1))\n",
    "\n",
    "    # Benchmark inference behaviour (BatchNorm/Dropout in eval mode), then restore the caller's mode.\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            # GPU warm-up: untimed passes so one-off initialisation does not pollute the samples.\n",
    "            for _ in range(100):\n",
    "                model(dummy_input)\n",
    "            torch.cuda.synchronize()\n",
    "\n",
    "            for i in range(rep):\n",
    "                starter.record()\n",
    "                model(dummy_input)\n",
    "                ender.record()\n",
    "                # Kernels launch asynchronously; wait for completion before reading the events.\n",
    "                torch.cuda.synchronize()\n",
    "                timings[i] = starter.elapsed_time(ender)\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "\n",
    "    return np.mean(timings), np.std(timings)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f08bd61e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def measure_time_cpu(model, device, rep = 10):\n",
    "    \"\"\"Measure single-sample inference latency of `model` with a wall-clock timer.\n",
    "\n",
    "    Intended for CPU execution: CUDA kernels run asynchronously, so host-side\n",
    "    timing without synchronisation would not reflect GPU work (use\n",
    "    measure_time_gpu for that). Passes run in eval mode with gradient tracking\n",
    "    disabled; the model's previous train/eval mode is restored before returning.\n",
    "\n",
    "    Args:\n",
    "        model: torch.nn.Module to benchmark.\n",
    "        device: device for the model and the input, normally 'cpu'.\n",
    "        rep: number of timed forward passes (default 10); must be at least 1.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (mean, std) of the per-pass latency in seconds. Note that\n",
    "        measure_time_gpu reports milliseconds.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `rep` is smaller than 1.\n",
    "    \"\"\"\n",
    "    # Validate first: zero repetitions would otherwise yield NaN statistics.\n",
    "    if rep < 1:\n",
    "        raise ValueError(f'rep must be >= 1, got {rep}')\n",
    "\n",
    "    model = model.to(device=device)\n",
    "    x = torch.rand((1, 1, 29, 29), device=device)\n",
    "    timings = np.zeros((rep, 1))\n",
    "\n",
    "    # Benchmark inference behaviour, then restore the caller's train/eval mode.\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            # One untimed pass so first-call overhead does not skew the first sample.\n",
    "            model(x)\n",
    "            for i in range(rep):\n",
    "                # perf_counter is monotonic and high resolution, unlike time.time().\n",
    "                start_time = time.perf_counter()\n",
    "                model(x)\n",
    "                timings[i] = time.perf_counter() - start_time\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "\n",
    "    return np.mean(timings), np.std(timings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "46511c12-302e-412d-9558-1c30d6598fd8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5.565476221561432, 0.2572142879522647)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean/std GPU latency (ms) of the Inception-ResNet classifier over 1000 timed passes.\n",
    "incep = SupIncepResnet(num_classes=5)\n",
    "measure_time_gpu(model=incep, device='cuda', rep=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b35e749d-bf36-4c56-a6a3-47e6732abf5e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5.958114630699158, 0.3380396613151823)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean/std GPU latency (ms) of the ResNet-18 baseline over 1000 timed passes.\n",
    "baseline_model = SupCEResNet(name='resnet18', num_classes=5)\n",
    "measure_time_gpu(model=baseline_model, device='cuda', rep=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ec6b048f-724f-4617-8650-d46c4daf0575",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.\n",
      "[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.\n",
      "[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.\n",
      "[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.\n",
      "[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.\n",
      "MACs (G):  0.032560592\n",
      "Params (M):  0.700533\n"
     ]
    }
   ],
   "source": [
    "# Complexity of the ResNet-18 baseline, counted by thop on CPU for a single 1x1x29x29 input.\n",
    "baseline_model = baseline_model.to(device='cpu')\n",
    "x = torch.rand((1, 1, 29, 29), device='cpu')\n",
    "macs, params = profile(baseline_model, inputs=(x, ))\n",
    "\n",
    "# Report in units of 1e9 MACs and 1e6 parameters.\n",
    "print('MACs (G): ', macs/1000**3)\n",
    "print('Params (M): ', params/1000**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "31daf0f2-281a-44e8-aede-c7d57a1205df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.\n",
      "[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.\n",
      "[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.\n",
      "[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.\n",
      "[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.\n",
      "[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.\n",
      "MACs (G):  0.097190176\n",
      "Params (M):  1.694181\n"
     ]
    }
   ],
   "source": [
    "# Complexity of the Inception-ResNet model, counted by thop on CPU for a single 1x1x29x29 input.\n",
    "incep = incep.to(device='cpu')\n",
    "x = torch.rand((1, 1, 29, 29), device='cpu')\n",
    "macs, params = profile(incep, inputs=(x, ))\n",
    "\n",
    "# Report in units of 1e9 MACs and 1e6 parameters.\n",
    "print('MACs (G): ', macs/1000**3)\n",
    "print('Params (M): ', params/1000**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "cf5bac89-5eee-4c5c-b616-c7ddd465c953",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1           [-1, 16, 29, 29]             144\n",
      "       BatchNorm2d-2           [-1, 16, 29, 29]              32\n",
      "            Conv2d-3           [-1, 16, 29, 29]           2,304\n",
      "       BatchNorm2d-4           [-1, 16, 29, 29]              32\n",
      "            Conv2d-5           [-1, 16, 29, 29]           2,304\n",
      "       BatchNorm2d-6           [-1, 16, 29, 29]              32\n",
      "        BasicBlock-7           [-1, 16, 29, 29]               0\n",
      "            Conv2d-8           [-1, 16, 29, 29]           2,304\n",
      "       BatchNorm2d-9           [-1, 16, 29, 29]              32\n",
      "           Conv2d-10           [-1, 16, 29, 29]           2,304\n",
      "      BatchNorm2d-11           [-1, 16, 29, 29]              32\n",
      "       BasicBlock-12           [-1, 16, 29, 29]               0\n",
      "           Conv2d-13           [-1, 32, 15, 15]           4,608\n",
      "      BatchNorm2d-14           [-1, 32, 15, 15]              64\n",
      "           Conv2d-15           [-1, 32, 15, 15]           9,216\n",
      "      BatchNorm2d-16           [-1, 32, 15, 15]              64\n",
      "           Conv2d-17           [-1, 32, 15, 15]             512\n",
      "      BatchNorm2d-18           [-1, 32, 15, 15]              64\n",
      "       BasicBlock-19           [-1, 32, 15, 15]               0\n",
      "           Conv2d-20           [-1, 32, 15, 15]           9,216\n",
      "      BatchNorm2d-21           [-1, 32, 15, 15]              64\n",
      "           Conv2d-22           [-1, 32, 15, 15]           9,216\n",
      "      BatchNorm2d-23           [-1, 32, 15, 15]              64\n",
      "       BasicBlock-24           [-1, 32, 15, 15]               0\n",
      "           Conv2d-25             [-1, 64, 8, 8]          18,432\n",
      "      BatchNorm2d-26             [-1, 64, 8, 8]             128\n",
      "           Conv2d-27             [-1, 64, 8, 8]          36,864\n",
      "      BatchNorm2d-28             [-1, 64, 8, 8]             128\n",
      "           Conv2d-29             [-1, 64, 8, 8]           2,048\n",
      "      BatchNorm2d-30             [-1, 64, 8, 8]             128\n",
      "       BasicBlock-31             [-1, 64, 8, 8]               0\n",
      "           Conv2d-32             [-1, 64, 8, 8]          36,864\n",
      "      BatchNorm2d-33             [-1, 64, 8, 8]             128\n",
      "           Conv2d-34             [-1, 64, 8, 8]          36,864\n",
      "      BatchNorm2d-35             [-1, 64, 8, 8]             128\n",
      "       BasicBlock-36             [-1, 64, 8, 8]               0\n",
      "           Conv2d-37            [-1, 128, 4, 4]          73,728\n",
      "      BatchNorm2d-38            [-1, 128, 4, 4]             256\n",
      "           Conv2d-39            [-1, 128, 4, 4]         147,456\n",
      "      BatchNorm2d-40            [-1, 128, 4, 4]             256\n",
      "           Conv2d-41            [-1, 128, 4, 4]           8,192\n",
      "      BatchNorm2d-42            [-1, 128, 4, 4]             256\n",
      "       BasicBlock-43            [-1, 128, 4, 4]               0\n",
      "           Conv2d-44            [-1, 128, 4, 4]         147,456\n",
      "      BatchNorm2d-45            [-1, 128, 4, 4]             256\n",
      "           Conv2d-46            [-1, 128, 4, 4]         147,456\n",
      "      BatchNorm2d-47            [-1, 128, 4, 4]             256\n",
      "       BasicBlock-48            [-1, 128, 4, 4]               0\n",
      "AdaptiveAvgPool2d-49            [-1, 128, 1, 1]               0\n",
      "           ResNet-50                  [-1, 128]               0\n",
      "           Linear-51                    [-1, 5]             645\n",
      "================================================================\n",
      "Total params: 700,533\n",
      "Trainable params: 700,533\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.00\n",
      "Forward/backward pass size (MB): 2.46\n",
      "Params size (MB): 2.67\n",
      "Estimated Total Size (MB): 5.13\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Layer-by-layer summary of the baseline for a (1, 29, 29) input; Module.cuda() returns the module itself.\n",
    "baseline_model = baseline_model.cuda()\n",
    "summary(baseline_model, (1, 29, 29))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "56eb30d1-60ce-4944-b917-2a012b1f49fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1           [-1, 32, 29, 29]             320\n",
      "            Conv2d-2           [-1, 32, 27, 27]           9,248\n",
      "         MaxPool2d-3           [-1, 32, 13, 13]               0\n",
      "            Conv2d-4           [-1, 64, 13, 13]           2,112\n",
      "            Conv2d-5          [-1, 128, 13, 13]          73,856\n",
      "            Conv2d-6          [-1, 128, 13, 13]         147,584\n",
      "              Stem-7          [-1, 128, 13, 13]               0\n",
      "            Conv2d-8           [-1, 32, 13, 13]           4,128\n",
      "            Conv2d-9           [-1, 32, 13, 13]           4,128\n",
      "           Conv2d-10           [-1, 32, 13, 13]           9,248\n",
      "           Conv2d-11           [-1, 32, 13, 13]           4,128\n",
      "           Conv2d-12           [-1, 32, 13, 13]           9,248\n",
      "           Conv2d-13           [-1, 32, 13, 13]           9,248\n",
      "           Conv2d-14          [-1, 128, 13, 13]          12,416\n",
      "             ReLU-15          [-1, 128, 13, 13]               0\n",
      "InceptionresenetA-16          [-1, 128, 13, 13]               0\n",
      "        MaxPool2d-17            [-1, 128, 6, 6]               0\n",
      "           Conv2d-18            [-1, 192, 6, 6]         221,376\n",
      "           Conv2d-19           [-1, 96, 13, 13]          12,384\n",
      "           Conv2d-20           [-1, 96, 13, 13]          83,040\n",
      "           Conv2d-21            [-1, 128, 6, 6]         110,720\n",
      "       ReductionA-22            [-1, 448, 6, 6]               0\n",
      "           Conv2d-23             [-1, 64, 6, 6]          28,736\n",
      "           Conv2d-24             [-1, 64, 6, 6]          28,736\n",
      "           Conv2d-25             [-1, 64, 6, 6]          12,352\n",
      "           Conv2d-26             [-1, 64, 6, 6]          12,352\n",
      "           Conv2d-27            [-1, 448, 6, 6]          57,792\n",
      "             ReLU-28            [-1, 448, 6, 6]               0\n",
      " InceptionresnetB-29            [-1, 448, 6, 6]               0\n",
      "        MaxPool2d-30            [-1, 448, 2, 2]               0\n",
      "           Conv2d-31            [-1, 128, 6, 6]          57,472\n",
      "           Conv2d-32            [-1, 192, 2, 2]         221,376\n",
      "           Conv2d-33            [-1, 128, 6, 6]          57,472\n",
      "           Conv2d-34            [-1, 128, 2, 2]         147,584\n",
      "           Conv2d-35            [-1, 128, 6, 6]          57,472\n",
      "           Conv2d-36            [-1, 128, 4, 4]         147,584\n",
      "           Conv2d-37            [-1, 128, 2, 2]         147,584\n",
      "       ReductionB-38            [-1, 896, 2, 2]               0\n",
      "AdaptiveAvgPool2d-39            [-1, 896, 1, 1]               0\n",
      "        Dropout2d-40            [-1, 896, 1, 1]               0\n",
      "  InceptionResnet-41                  [-1, 896]               0\n",
      "           Linear-42                    [-1, 5]           4,485\n",
      "================================================================\n",
      "Total params: 1,694,181\n",
      "Trainable params: 1,694,181\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.00\n",
      "Forward/backward pass size (MB): 2.87\n",
      "Params size (MB): 6.46\n",
      "Estimated Total Size (MB): 9.34\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Layer-by-layer summary of the Inception-ResNet model for a (1, 29, 29) input; Module.cuda() returns the module itself.\n",
    "incep = incep.cuda()\n",
    "summary(incep, (1, 29, 29))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "ec8a7a313ab33d199c8aa698bb86bd912b8385ce4922a6e184e3f5edd5eb95f6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
